{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "colab_type": "code",
    "id": "OzG-MRwD8HgI",
    "outputId": "27821ee5-c0c7-430e-e532-05caa599b621"
   },
   "outputs": [],
   "source": [
    "# For the 'InstanceNormalization' layer\n",
    "!pip install --upgrade tensorflow_addons\n",
    "# For the dataset\n",
    "!pip install --upgrade tensorflow_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "eWBg9-PI8SZG"
   },
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.initializers import RandomNormal\n",
    "from tensorflow.keras.layers import (Activation, Concatenate, Conv2D,\n",
    "                                     Conv2DTranspose, Input, LeakyReLU)\n",
    "from tensorflow.keras.losses import BinaryCrossentropy\n",
    "from tensorflow.keras.models import Model, Sequential\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "from tensorflow_addons.layers import InstanceNormalization\n",
    "\n",
    "import tensorflow_datasets as tfds\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 309
    },
    "colab_type": "code",
    "id": "B6M8YY-58Sb6",
    "outputId": "bc246112-8e3c-4660-8bc4-b7451e7ae296"
   },
   "outputs": [],
   "source": [
    "# Load dataset\n",
    "data, metadata = tfds.load('cycle_gan/monet2photo', with_info=True, as_supervised=True)\n",
    "\n",
    "train_x, train_y, test_x, test_y = data['trainA'], data['trainB'], data['testA'], data['testB']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "tq1rsy-o8SeN"
   },
   "outputs": [],
   "source": [
    "# Settings\n",
    "epochs = 50\n",
    "\n",
    "LAMBDA = 10\n",
    "\n",
    "img_rows, img_cols, channels = 256, 256, 3\n",
    "weight_initializer = RandomNormal(stddev=0.02)\n",
    "\n",
    "gen_g_optimizer = gen_f_optimizer = Adam(lr=0.0002, beta_1=0.5)\n",
    "dis_x_optimizer = dis_y_optimizer = Adam(lr=0.0002, beta_1=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Hq6Knn_C8SgY"
   },
   "outputs": [],
   "source": [
    "# Normalize images to [-1, 1] and reshape\n",
    "def preprocess_image(image, _):\n",
    "    return tf.reshape(tf.cast(tf.image.resize(image, (int(img_rows), int(img_cols))), tf.float32) / 127.5 - 1, (1, img_rows, img_cols, channels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "wrrpsmCE8SjE"
   },
   "outputs": [],
   "source": [
    "# Map the normalization onto the dataset\n",
    "train_x = train_x.map(preprocess_image)\n",
    "train_y = train_y.map(preprocess_image)\n",
    "test_x = test_x.map(preprocess_image)\n",
    "test_y = test_y.map(preprocess_image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "MZG9um_J8SmD"
   },
   "outputs": [],
   "source": [
    "# \"Ck denotes a 4 × 4 Convolution-InstanceNorm-LeakyReLU layer with k filters and stride 2\n",
    "def Ck(input, k, use_instancenorm=True):\n",
    "    block = Conv2D(k, (4, 4), strides=2, padding='same', kernel_initializer=weight_initializer)(input)\n",
    "    if use_instancenorm:\n",
    "        block = InstanceNormalization(axis=-1)(block)\n",
    "    block = LeakyReLU(0.2)(block)\n",
    "\n",
    "    return block\n",
    "\n",
    "# C64, C128, C256, C512\n",
    "def discriminator():\n",
    "    dis_input = Input(shape=(img_rows, img_cols, channels))\n",
    "\n",
    "    d = Ck(dis_input, 64, False)\n",
    "    d = Ck(d, 128)\n",
    "    d = Ck(d, 256)\n",
    "    d = Ck(d, 512)\n",
    "\n",
    "    d = Conv2D(1, (4, 4), padding='same', kernel_initializer=weight_initializer)(d)\n",
    "\n",
    "    return Model(dis_input, d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ByAGo-Uc8SoS"
   },
   "outputs": [],
   "source": [
    "# \"dk denotes a 3×3 Convolution-InstanceNorm-ReLU with k filters and stride 2\"\n",
    "def dk(k, use_instancenorm=True):\n",
    "    block = Sequential()\n",
    "    block.add(Conv2D(k, (3, 3), strides=2, padding='same', kernel_initializer=weight_initializer))\n",
    "    if use_instancenorm:\n",
    "        block.add(InstanceNormalization(axis=-1))\n",
    "    block.add(Activation('relu'))\n",
    "\n",
    "    return block\n",
    "\n",
    "# \"uk denotes a 3×3 fractional-strided-ConvolutionInstanceNorm-ReLU layer with k filters and stride ½\"\n",
    "def uk(k):\n",
    "    block = Sequential()\n",
    "    block.add(Conv2DTranspose(k, (3, 3), strides=2, padding='same', kernel_initializer=weight_initializer))\n",
    "    block.add(InstanceNormalization(axis=-1))\n",
    "    block.add(Activation('relu'))\n",
    "\n",
    "    return block\n",
    "\n",
    "def generator():\n",
    "    gen_input = Input(shape=(img_rows, img_cols, channels))\n",
    "    \n",
    "    # Layers for the encoder part of the model\n",
    "    encoder_layers = [\n",
    "        dk(64, False),\n",
    "        dk(128),\n",
    "        dk(256),\n",
    "        dk(512),\n",
    "        dk(512),\n",
    "        dk(512),\n",
    "        dk(512),\n",
    "        dk(512)\n",
    "    ]\n",
    "\n",
    "    # Layers for the decoder part of the model\n",
    "    decoder_layers = [\n",
    "        uk(512),\n",
    "        uk(512),\n",
    "        uk(512),\n",
    "        uk(512),\n",
    "        uk(256),\n",
    "        uk(128),\n",
    "        uk(64)\n",
    "    ]\n",
    "\n",
    "    gen = gen_input\n",
    "\n",
    "    # Add all the encoder layers, and keep track of them for skip connections\n",
    "    skips = []\n",
    "    for layer in encoder_layers:\n",
    "        gen = layer(gen)\n",
    "        skips.append(gen)\n",
    "    \n",
    "    skips = skips[::-1][1:] # Reverse for looping and get rid of the layer that directly connects to decoder\n",
    "\n",
    "    # Add all the decoder layers and skip connections\n",
    "    for skip_layer, layer in zip(skips, decoder_layers):\n",
    "        gen = layer(gen)\n",
    "        gen = Concatenate()([gen, skip_layer])\n",
    "\n",
    "    \n",
    "    # Final layer\n",
    "    gen = Conv2DTranspose(channels, (3, 3), strides=2, padding='same', kernel_initializer=weight_initializer, activation='tanh')(gen)\n",
    "    \n",
    "    # Compose model\n",
    "    return Model(gen_input, gen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "7ZZOaJKkIAT2"
   },
   "outputs": [],
   "source": [
    "# Define the models\n",
    "generator_g = generator()\n",
    "generator_f = generator()\n",
    "\n",
    "discriminator_x = discriminator()\n",
    "discriminator_y = discriminator()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "8s-DdUc68Sqs"
   },
   "outputs": [],
   "source": [
    " # Losses\n",
    "loss = BinaryCrossentropy(from_logits=True)\n",
    "\n",
    "# Measures how close to one real images are rated, and how close to zero fake images are rated\n",
    "def discriminator_loss(real, generated):\n",
    "    # Multiplied by 0.5 so that it will train at half-speed\n",
    "    return (loss(tf.ones_like(real), real) + loss(tf.zeros_like(generated), generated)) * 0.5\n",
    "\n",
    "# Measures how real the discriminator believes the fake image is\n",
    "def gen_loss(validity):\n",
    "    return loss(tf.ones_like(validity), validity)\n",
    "\n",
    "# Measures similarity of two images.  Used for cycle and identity loss\n",
    "def image_similarity(image1, image2):\n",
    "    return tf.reduce_mean(tf.abs(image1 - image2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "WYUJS3CF8Ss7"
   },
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def step(real_x, real_y):\n",
    "    with tf.GradientTape(persistent=True) as tape:\n",
    "        # Setup Dy loss\n",
    "        fake_y = generator_g(real_x, training=True)\n",
    "        gen_g_validity = discriminator_y(fake_y, training=True)\n",
    "        dis_y_loss = discriminator_loss(discriminator_y(real_y, training=True), gen_g_validity)\n",
    "\n",
    "        with tape.stop_recording():\n",
    "            discriminator_y_gradients = tape.gradient(dis_y_loss, discriminator_y.trainable_variables)\n",
    "            dis_y_optimizer.apply_gradients(zip(discriminator_y_gradients, discriminator_y.trainable_variables))\n",
    "\n",
    "        # Setup Dx loss\n",
    "        fake_x = generator_f(real_y, training=True)\n",
    "        gen_f_validity = discriminator_x(fake_x, training=True)\n",
    "        dis_x_loss = discriminator_loss(discriminator_x(real_x, training=True), gen_f_validity)\n",
    "\n",
    "        with tape.stop_recording():\n",
    "            discriminator_x_gradients = tape.gradient(dis_x_loss, discriminator_x.trainable_variables)\n",
    "            dis_x_optimizer.apply_gradients(zip(discriminator_x_gradients, discriminator_x.trainable_variables))\n",
    "\n",
    "        # Setup adversarial losses\n",
    "        gen_g_adv_loss = gen_loss(gen_g_validity)\n",
    "        gen_f_adv_loss = gen_loss(gen_f_validity)\n",
    "\n",
    "        # Setup cycle losses\n",
    "        cyc_x = generator_f(fake_y, training=True)\n",
    "        cyc_x_loss = image_similarity(real_x, cyc_x)\n",
    "\n",
    "        cyc_y = generator_g(fake_x, training=True)\n",
    "        cyc_y_loss =  image_similarity(real_y, cyc_y)\n",
    "\n",
    "        # Setup identity losses\n",
    "        id_x = generator_f(real_x, training=True)\n",
    "        id_x_loss = image_similarity(real_x, id_x)\n",
    "\n",
    "        id_y = generator_g(real_y, training=True)\n",
    "        id_y_loss = image_similarity(real_y, id_y)\n",
    "\n",
    "        # Finalize generator losses and calc gradients\n",
    "        gen_g_loss = gen_g_adv_loss + (cyc_x_loss + cyc_y_loss) * LAMBDA + id_y_loss * 0.5*LAMBDA\n",
    "        gen_f_loss = gen_f_adv_loss + (cyc_x_loss + cyc_y_loss) * LAMBDA + id_x_loss * 0.5*LAMBDA\n",
    "\n",
    "        with tape.stop_recording():\n",
    "            generator_g_gradients = tape.gradient(gen_g_loss, generator_g.trainable_variables)\n",
    "            gen_g_optimizer.apply_gradients(zip(generator_g_gradients, generator_g.trainable_variables))\n",
    "\n",
    "            generator_f_gradients = tape.gradient(gen_f_loss, generator_f.trainable_variables)\n",
    "            gen_f_optimizer.apply_gradients(zip(generator_f_gradients, generator_f.trainable_variables))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "uO62UsfgFYn1"
   },
   "outputs": [],
   "source": [
    "def generate_images():\n",
    "    # Sample images\n",
    "    x = next(iter(test_x.shuffle(1000))).numpy()\n",
    "    y = next(iter(test_y.shuffle(1000))).numpy()\n",
    "    \n",
    "    # Get predictions for those images\n",
    "    y_hat = generator_g.predict(x.reshape((1, img_rows, img_cols, channels)))\n",
    "    x_hat = generator_f.predict(y.reshape((1, img_rows, img_cols, channels)))\n",
    "    \n",
    "    plt.figure(figsize=(12, 12))\n",
    "\n",
    "    images = [x[0], y_hat[0], y[0], x_hat[0]]\n",
    "\n",
    "    for i in range(4):\n",
    "        plt.subplot(2, 2, i+1)\n",
    "        plt.imshow(images[i] * 0.5 + 0.5)\n",
    "        plt.axis('off')\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "colab_type": "code",
    "id": "gPM73sAi8Svs",
    "outputId": "519b49ca-51aa-478c-8d6f-bfaeeaec731f",
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Manually loop through epochs\n",
    "for epoch in range(epochs):\n",
    "    print('Epoch: {}'.format(epoch))\n",
    "    start = time.time()\n",
    "\n",
    "    # Each batch\n",
    "    for k, (real_x, real_y) in enumerate(tf.data.Dataset.zip((train_x, train_y))):\n",
    "        if k % 100 == 0: print(k)\n",
    "        # Train step\n",
    "        step(tf.reshape(real_x, (1, img_rows, img_cols, channels)), tf.reshape(real_y, (1, img_rows, img_cols, channels)))\n",
    "    \n",
    "    # View progress\n",
    "    generate_images()\n",
    "    print('Time taken: {}'.format(time.time() - start))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "NxvtP4Zn8Sxp"
   },
   "outputs": [],
   "source": [
    "for _ in range(10):\n",
    "    generate_images()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "eYSCjiD68Sz_"
   },
   "outputs": [],
   "source": [
    "generator_g.save('generator_g.h5')\n",
    "generator_f.save('generator_f.h5')\n",
    "discriminator_x.save('discriminator_x.h5')\n",
    "discriminator_y.save('discriminator_y.h5')"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "FINAL.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}